{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SqueezeNet v1.1\n",
    "\n",
    "original repo: **https://github.com/DeepScale/SqueezeNet**\n",
    "\n",
    "for keras: **https://github.com/rcmalli/keras-squeezenet** (pretrained imagenet weights downloaded from here)\n",
    "\n",
    "```\n",
    "$ wget -O squeezenet_v1.1.h5 https://github.com/rcmalli/keras-squeezenet/releases/download/v1.0/squeezenet_weights_tf_dim_ordering_tf_kernels.h5\n",
    "```\n",
    "\n",
    "The pre-trained ImageNet weights are renamed from `squeezenet_weights_tf_dim_ordering_tf_kernels.h5` to `squeezenet_v1.1.h5`. Weight names are converted to Keras v2 format, since Keras.js uses these names to load layer weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path of the Keras HDF5 file. The cells below modify it in place: weights are\n",
    "# renamed inside it, loaded from it, and the full model is saved back to it.\n",
    "KERAS_MODEL_FILEPATH = '../../demos/data/squeezenet_v1.1/squeezenet_v1.1.h5'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['conv1_W:0', 'conv1_b:0'] --> ['conv1/kernel:0', 'conv1/bias:0']\n",
      "['fire2/squeeze1x1_W:0', 'fire2/squeeze1x1_b:0'] --> ['fire2/squeeze1x1/kernel:0', 'fire2/squeeze1x1/bias:0']\n",
      "['fire2/expand1x1_W:0', 'fire2/expand1x1_b:0'] --> ['fire2/expand1x1/kernel:0', 'fire2/expand1x1/bias:0']\n",
      "['fire2/expand3x3_W:0', 'fire2/expand3x3_b:0'] --> ['fire2/expand3x3/kernel:0', 'fire2/expand3x3/bias:0']\n",
      "['fire3/squeeze1x1_W:0', 'fire3/squeeze1x1_b:0'] --> ['fire3/squeeze1x1/kernel:0', 'fire3/squeeze1x1/bias:0']\n",
      "['fire3/expand1x1_W:0', 'fire3/expand1x1_b:0'] --> ['fire3/expand1x1/kernel:0', 'fire3/expand1x1/bias:0']\n",
      "['fire3/expand3x3_W:0', 'fire3/expand3x3_b:0'] --> ['fire3/expand3x3/kernel:0', 'fire3/expand3x3/bias:0']\n",
      "['fire4/squeeze1x1_W:0', 'fire4/squeeze1x1_b:0'] --> ['fire4/squeeze1x1/kernel:0', 'fire4/squeeze1x1/bias:0']\n",
      "['fire4/expand1x1_W:0', 'fire4/expand1x1_b:0'] --> ['fire4/expand1x1/kernel:0', 'fire4/expand1x1/bias:0']\n",
      "['fire4/expand3x3_W:0', 'fire4/expand3x3_b:0'] --> ['fire4/expand3x3/kernel:0', 'fire4/expand3x3/bias:0']\n",
      "['fire5/squeeze1x1_W:0', 'fire5/squeeze1x1_b:0'] --> ['fire5/squeeze1x1/kernel:0', 'fire5/squeeze1x1/bias:0']\n",
      "['fire5/expand1x1_W:0', 'fire5/expand1x1_b:0'] --> ['fire5/expand1x1/kernel:0', 'fire5/expand1x1/bias:0']\n",
      "['fire5/expand3x3_W:0', 'fire5/expand3x3_b:0'] --> ['fire5/expand3x3/kernel:0', 'fire5/expand3x3/bias:0']\n",
      "['fire6/squeeze1x1_W:0', 'fire6/squeeze1x1_b:0'] --> ['fire6/squeeze1x1/kernel:0', 'fire6/squeeze1x1/bias:0']\n",
      "['fire6/expand1x1_W:0', 'fire6/expand1x1_b:0'] --> ['fire6/expand1x1/kernel:0', 'fire6/expand1x1/bias:0']\n",
      "['fire6/expand3x3_W:0', 'fire6/expand3x3_b:0'] --> ['fire6/expand3x3/kernel:0', 'fire6/expand3x3/bias:0']\n",
      "['fire7/squeeze1x1_W:0', 'fire7/squeeze1x1_b:0'] --> ['fire7/squeeze1x1/kernel:0', 'fire7/squeeze1x1/bias:0']\n",
      "['fire7/expand1x1_W:0', 'fire7/expand1x1_b:0'] --> ['fire7/expand1x1/kernel:0', 'fire7/expand1x1/bias:0']\n",
      "['fire7/expand3x3_W:0', 'fire7/expand3x3_b:0'] --> ['fire7/expand3x3/kernel:0', 'fire7/expand3x3/bias:0']\n",
      "['fire8/squeeze1x1_W:0', 'fire8/squeeze1x1_b:0'] --> ['fire8/squeeze1x1/kernel:0', 'fire8/squeeze1x1/bias:0']\n",
      "['fire8/expand1x1_W:0', 'fire8/expand1x1_b:0'] --> ['fire8/expand1x1/kernel:0', 'fire8/expand1x1/bias:0']\n",
      "['fire8/expand3x3_W:0', 'fire8/expand3x3_b:0'] --> ['fire8/expand3x3/kernel:0', 'fire8/expand3x3/bias:0']\n",
      "['fire9/squeeze1x1_W:0', 'fire9/squeeze1x1_b:0'] --> ['fire9/squeeze1x1/kernel:0', 'fire9/squeeze1x1/bias:0']\n",
      "['fire9/expand1x1_W:0', 'fire9/expand1x1_b:0'] --> ['fire9/expand1x1/kernel:0', 'fire9/expand1x1/bias:0']\n",
      "['fire9/expand3x3_W:0', 'fire9/expand3x3_b:0'] --> ['fire9/expand3x3/kernel:0', 'fire9/expand3x3/bias:0']\n",
      "['conv10_W:0', 'conv10_b:0'] --> ['conv10/kernel:0', 'conv10/bias:0']\n"
     ]
    }
   ],
   "source": [
    "import h5py\n",
    "\n",
    "def keras2_weight_name(weight_name):\n",
    "    \"\"\"Map a Keras 1 weight name to its Keras 2 form.\n",
    "\n",
    "    'conv1_W:0' -> 'conv1/kernel:0' and 'conv1_b:0' -> 'conv1/bias:0'.\n",
    "    Only the '_W:' / '_b:' marker is rewritten, so a '_W' or '_b' that\n",
    "    appears elsewhere in the name is left alone. Any other name is\n",
    "    returned unchanged.\n",
    "    \"\"\"\n",
    "    for old_marker, new_marker in (('_W:', '/kernel:'), ('_b:', '/bias:')):\n",
    "        if old_marker in weight_name:\n",
    "            return weight_name.replace(old_marker, new_marker)\n",
    "    return weight_name\n",
    "\n",
    "\n",
    "# Rename the weights inside the HDF5 file in place. The context manager\n",
    "# closes the file even if a lookup or rename raises part-way through.\n",
    "with h5py.File(KERAS_MODEL_FILEPATH, mode='r+') as hdf5_file:\n",
    "    layer_names = [n.decode('utf8') for n in hdf5_file.attrs['layer_names']]\n",
    "    for layer_name in layer_names:\n",
    "        g = hdf5_file[layer_name]\n",
    "        weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]\n",
    "        if not weight_names:\n",
    "            # Layers without weights have nothing to rename.\n",
    "            continue\n",
    "        new_weight_names = []\n",
    "        for weight_name in weight_names:\n",
    "            new_weight_name = keras2_weight_name(weight_name)\n",
    "            if new_weight_name != weight_name:\n",
    "                # Expose the dataset under its new name, then drop the old name.\n",
    "                g[new_weight_name] = g[weight_name]\n",
    "                del g[weight_name]\n",
    "            new_weight_names.append(new_weight_name)\n",
    "        # Keep the group's name list in sync with the renamed datasets.\n",
    "        g.attrs['weight_names'] = [n.encode('utf8') for n in new_weight_names]\n",
    "        print(weight_names, '-->', new_weight_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n",
      "/home/leon/miniconda3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: compiletime version 3.5 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.6\n",
      "  return f(*args, **kwds)\n"
     ]
    }
   ],
   "source": [
    "from keras.layers import Input, Conv2D, MaxPooling2D, Activation, concatenate, Dropout, GlobalAveragePooling2D\n",
    "from keras.models import Model\n",
    "\n",
    "# Name fragments that fire_module concatenates to build its layer names.\n",
    "sq1x1 = \"squeeze1x1\"\n",
    "exp1x1 = \"expand1x1\"\n",
    "exp3x3 = \"expand3x3\"\n",
    "relu = \"relu_\"\n",
    "\n",
    "def fire_module(x, fire_id, squeeze=16, expand=64):\n",
    "    \"\"\"Append one fire module to `x` and return the resulting tensor.\n",
    "\n",
    "    A 1x1 squeeze convolution (`squeeze` filters) feeds two expand\n",
    "    convolutions (`expand` filters each, 1x1 and 3x3). Every convolution\n",
    "    is followed by a ReLU, and the two expand branches are concatenated\n",
    "    on axis 3. All layer names are prefixed with 'fire<fire_id>/'.\n",
    "    \"\"\"\n",
    "    prefix = 'fire{}/'.format(fire_id)\n",
    "\n",
    "    squeezed = Conv2D(squeeze, (1, 1), padding='valid', name=prefix + sq1x1)(x)\n",
    "    squeezed = Activation('relu', name=prefix + relu + sq1x1)(squeezed)\n",
    "\n",
    "    # Expand branches: the 1x1 branch is created first, then the 3x3 branch.\n",
    "    branches = []\n",
    "    for kernel_size, padding, tag in (((1, 1), 'valid', exp1x1), ((3, 3), 'same', exp3x3)):\n",
    "        branch = Conv2D(expand, kernel_size, padding=padding, name=prefix + tag)(squeezed)\n",
    "        branches.append(Activation('relu', name=prefix + relu + tag)(branch))\n",
    "\n",
    "    return concatenate(branches, axis=3, name=prefix + 'concat')\n",
    "\n",
    "\n",
    "# Build the network graph layer by layer. Layer names are set explicitly\n",
    "# (except the global pooling layer) and the statement order fixes the order\n",
    "# in which layers are created.\n",
    "input_shape = (227, 227, 3)\n",
    "classes = 1000\n",
    "\n",
    "img_input = Input(shape=input_shape)\n",
    "\n",
    "# Stem: 3x3 stride-2 convolution, ReLU, then 3x3 stride-2 max pooling.\n",
    "x = Conv2D(64, (3, 3), strides=(2, 2), padding='valid', name='conv1')(img_input)\n",
    "x = Activation('relu', name='relu_conv1')(x)\n",
    "x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), name='pool1')(x)\n",
    "\n",
    "# Fire modules 2-3, followed by max pooling.\n",
    "x = fire_module(x, fire_id=2, squeeze=16, expand=64)\n",
    "x = fire_module(x, fire_id=3, squeeze=16, expand=64)\n",
    "x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), name='pool3')(x)\n",
    "\n",
    "# Fire modules 4-5, followed by max pooling.\n",
    "x = fire_module(x, fire_id=4, squeeze=32, expand=128)\n",
    "x = fire_module(x, fire_id=5, squeeze=32, expand=128)\n",
    "x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), name='pool5')(x)\n",
    "\n",
    "# Fire modules 6-9 (no pooling in between), then dropout.\n",
    "x = fire_module(x, fire_id=6, squeeze=48, expand=192)\n",
    "x = fire_module(x, fire_id=7, squeeze=48, expand=192)\n",
    "x = fire_module(x, fire_id=8, squeeze=64, expand=256)\n",
    "x = fire_module(x, fire_id=9, squeeze=64, expand=256)\n",
    "x = Dropout(0.5, name='drop9')(x)\n",
    "\n",
    "# Head: 1x1 convolution to `classes` channels, ReLU, global average\n",
    "# pooling, then softmax.\n",
    "x = Conv2D(classes, (1, 1), padding='valid', name='conv10')(x)\n",
    "x = Activation('relu', name='relu_conv10')(x)\n",
    "x = GlobalAveragePooling2D()(x)\n",
    "out = Activation('softmax', name='loss')(x)\n",
    "\n",
    "model = Model(inputs=img_input, outputs=out, name='squeezenet')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the weights from the HDF5 file into the model built above.\n",
    "model.load_weights(KERAS_MODEL_FILEPATH)\n",
    "# NOTE(review): no training happens in this notebook; compile() is presumably\n",
    "# here only so the saved model carries an optimizer/loss config - confirm.\n",
    "model.compile(optimizer='adam', loss='categorical_crossentropy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the model to the same path the weights were loaded from, overwriting\n",
    "# that file.\n",
    "# NOTE(review): after this the file may no longer have the layout the\n",
    "# weight-renaming cell expects, so a full re-run likely needs a fresh\n",
    "# download of the weights - confirm.\n",
    "model.save(KERAS_MODEL_FILEPATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
